#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Date    : 2019-01-27 17:08:22
# @Author  : jimmy (jimmywangheng@qq.com)
# @Link    : http://sdcs.sysu.edu.cn
# @Version : $Id$

import os
from copy import deepcopy
import pickle
import math

import numpy as np
import collections
from itertools import count
from sklearn.metrics.pairwise import cosine_similarity
import sys
import random
import time
import datetime

from networks import policy_nn_lstm_attn
from utils import *
from env import Env

import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from BFS.KB import *

import tensorflow as tf
from keras.backend.tensorflow_backend import set_session  
# Create a TensorFlow session whose GPU options have allow_growth enabled
# and register it as the session used by the Keras TensorFlow backend.
config = tf.ConfigProto()  
config.gpu_options.allow_growth = True  
set_session(tf.Session(config=config))

from sklearn import linear_model
from keras.models import Sequential 
from keras.layers import Dense, Activation

from ConvE import ConvE_double

import argparse
# Command-line interface. Every option is (short flag, long flag, type, default);
# they are registered in this order.
argparser = argparse.ArgumentParser()
for _short_flag, _long_flag, _arg_type, _arg_default in (
	('-a', '--attn_dim', int, 100),
	('-wr', '--wrong_reward', float, -0.05),
	('-ur', '--useless_reward', float, 0.01),
	('-tr', '--teacher_reward', float, 1.0),
	('-gw', '--global_reward_weight', float, 0.1),
	('-lw', '--length_reward_weight', float, 0.8),
	('-dw', '--diverse_reward_weight', float, 0.1),
	('-np', '--num_episodes', int, 500),
	('-wd', '--weight_decay', float, 0.005),
	('-d2', '--dropout2', float, 0.3),
	('-adr', '--action_dropout_rate', float, 0.3),
	('-eb', '--exp_base', float, math.e),
	('-r', '--relation', str, 'concept_agentbelongstoorganization'),
	('-t', '--task', str, 'retrain'),
	('-hue', '--hidden_update_everytime', int, 0),
	('-mo', '--model', str, "TransD"),
	('-remo', '--reward_shaping_model', str, "ConvE"),
):
	argparser.add_argument(_short_flag, _long_flag, type=_arg_type, default=_arg_default)

args = argparser.parse_args()
print("Relation: ", args.relation)

# Timestamp (down to microseconds) that prefixes every artefact of this run.
time_str = datetime.datetime.now().strftime('%Y%m%d%H%M%S%f')

# The header encodes the timestamp followed by "<tag>_<value>" pairs for the
# hyper-parameters and a fixed experiment label; it is reused for output file names.
_header_fields = (
	('a', str(args.attn_dim)),
	('wr', str(args.wrong_reward)),
	('ur', str(args.useless_reward)),
	('tr', str(args.teacher_reward)),
	('gw', str(args.global_reward_weight)),
	('lw', str(args.length_reward_weight)),
	('dw', str(args.diverse_reward_weight)),
	('np', str(args.num_episodes)),
	('wd', str(args.weight_decay)),
	('d2', str(args.dropout2)),
	('adr', str(args.action_dropout_rate)),
	('eb', "%.4f" % args.exp_base),
	('hue', str(args.hidden_update_everytime)),
	('mo', str(args.model)),
	('remo', str(args.reward_shaping_model)),
)
save_file_header = '_'.join(
	[time_str]
	+ [_piece for _field in _header_fields for _piece in _field]
	+ ['no_pretraining_dynamic_attention_lstm_reward_shaping_select_path_three_regularization_methods']
)

# Choose the tensor constructors once: the CUDA variants when a GPU is
# available, the CPU variants otherwise.
USE_CUDA = torch.cuda.is_available()
_tensor_namespace = torch.cuda if USE_CUDA else torch
longTensor = _tensor_namespace.LongTensor
floatTensor = _tensor_namespace.FloatTensor
byteTensor = _tensor_namespace.ByteTensor

relation = args.relation
task = args.task

# Task files below the shared data root (dataPath is not defined in this
# file; presumably it comes from `from utils import *` — confirm).
graphpath = dataPath + 'tasks/' + relation + '/' + 'graph.txt'
relationPath = dataPath + 'tasks/' + relation + '/' + 'train_pos'

# Files addressed relative to the NELL-995 directory.
dataPath_ = '../NELL-995/tasks/'  + relation + '/'
filename = save_file_header + '_path_stats.txt'
feature_stats = dataPath_ + 'path_stats/' + filename
relationId_path = '../NELL-995/relation2id.txt'
ent_id_path = '../NELL-995/' + 'entity2id.txt'
rel_id_path = '../NELL-995/' + 'relation2id.txt'
test_data_path = '../NELL-995/tasks/'  + relation + '/sort_test.pairs'

selectionPath = dataPath + 'selection_proportion.txt'
replacementPath = dataPath + 'replacement_proportion.txt'


def _read_proportions(table_path):
	"""Read a whitespace-separated "<key> <float>" file into a dict."""
	with open(table_path) as table_file:
		return {
			fields[0]: float(fields[1])
			for fields in (row.strip().split() for row in table_file)
		}


selection_dict = _read_proportions(selectionPath)
replacement_dict = _read_proportions(replacementPath)

# All three regularisation strengths share one factor: exp_base raised to the
# sum of this relation's selection and replacement proportions. Each scaled
# value is then capped (weight decay at 1, both dropout rates at 0.5).
_regularization_scale = args.exp_base ** (selection_dict[relation] + replacement_dict[relation])

dynamic_weight_decay = min(args.weight_decay * _regularization_scale, 1)
dynamic_action_dropout_rate = min(args.action_dropout_rate * _regularization_scale, 0.5)
dynamic_dropout_rate2 = min(args.dropout2 * _regularization_scale, 0.5)

# Which model is used to compute reward shaping?
# Depending on --reward_shaping_model, the pretrained artefacts of that model
# are loaded into module-level globals that REINFORCE reads later. Any value
# other than TransH / TransR / TransD / ProjE / ConvE falls back to TransE.
if args.reward_shaping_model == "TransH":
	print("Uses TransH as reward shaping model.")
	ent_embedding_path = '../NELL-995/NELL-995_100_1.0_TransH_entity_embedding.txt'
	rel_embedding_path = '../NELL-995/NELL-995_100_1.0_TransH_relation_embedding.txt'
	norm_embedding_path = '../NELL-995/NELL-995_100_1.0_TransH_norm_embedding.txt'
	ent_embedding = np.loadtxt(ent_embedding_path)
	rel_embedding = np.loadtxt(rel_embedding_path)
	norm_embedding = np.loadtxt(norm_embedding_path)

elif args.reward_shaping_model == "TransR":
	print("Uses TransR as reward shaping model.")
	ent_embedding_path = '../NELL-995/NELL-995_100_1.0_TransR_entity_embedding.txt'
	rel_embedding_path = '../NELL-995/NELL-995_100_1.0_TransR_relation_embedding.txt'
	norm_embedding_path = '../NELL-995/NELL-995_100_1.0_TransR_norm_embedding.txt'
	ent_embedding = np.loadtxt(ent_embedding_path)
	rel_embedding = np.loadtxt(rel_embedding_path)
	norm_embedding = np.loadtxt(norm_embedding_path)
	# One 100x100 matrix per row of the loaded file.
	norm_embedding = norm_embedding.reshape([-1, 100, 100])

elif args.reward_shaping_model == "TransD":
	print("Uses TransD as reward shaping model.")
	ent_embedding_path = '../NELL-995/NELL-995_100_1.0_TransD_entity_embedding.txt'
	rel_embedding_path = '../NELL-995/NELL-995_100_1.0_TransD_relation_embedding.txt'
	ent_norm_embedding_path = '../NELL-995/NELL-995_100_1.0_TransD_ent_norm_embedding.txt'
	rel_norm_embedding_path = '../NELL-995/NELL-995_100_1.0_TransD_rel_norm_embedding.txt'
	ent_embedding = np.loadtxt(ent_embedding_path)
	rel_embedding = np.loadtxt(rel_embedding_path)
	ent_norm_embedding = np.loadtxt(ent_norm_embedding_path)
	rel_norm_embedding = np.loadtxt(rel_norm_embedding_path)

elif args.reward_shaping_model == "ProjE":
	print("Uses ProjE as reward shaping model.")
	ent_embedding_path = '../NELL-995/NELL-995_100_ProjE_entity_embedding.txt'
	rel_embedding_path = '../NELL-995/NELL-995_100_ProjE_relation_embedding.txt'
	simple_hr_combination_weights_path = '../NELL-995/NELL-995_100_ProjE_simple_hr_combination_weights.txt'
	simple_tr_combination_weights_path = '../NELL-995/NELL-995_100_ProjE_simple_tr_combination_weights.txt'
	combination_bias_hr_path = '../NELL-995/NELL-995_100_ProjE_combination_bias_hr.txt'
	combination_bias_tr_path = '../NELL-995/NELL-995_100_ProjE_combination_bias_tr.txt'

	ent_embedding = np.loadtxt(ent_embedding_path)
	rel_embedding = np.loadtxt(rel_embedding_path)
	simple_hr_combination_weights = np.loadtxt(simple_hr_combination_weights_path)
	simple_tr_combination_weights = np.loadtxt(simple_tr_combination_weights_path)
	combination_bias_hr = np.loadtxt(combination_bias_hr_path)
	combination_bias_tr = np.loadtxt(combination_bias_tr_path)

elif args.reward_shaping_model == "ConvE":
	print("Uses ConvE as reward shaping model.")
	# Integer-to-integer id maps read from "<id> <id>" lines; REINFORCE uses
	# them to translate env ids into the ids the ConvE model is indexed by.
	TransE_to_ConvE_id_entity = {}
	with open("../NELL-995/TransE_to_ConvE_entity_id.txt") as fr:
		for line in fr:
			line_list = line.strip().split()
			TransE_to_ConvE_id_entity[int(line_list[0])] = int(line_list[1])

	TransE_to_ConvE_id_relation = {}
	with open("../NELL-995/TransE_to_ConvE_relation_id.txt") as fr:
		for line in fr:
			line_list = line.strip().split()
			TransE_to_ConvE_id_relation[int(line_list[0])] = int(line_list[1])

	homepath = os.path.expanduser('~')
	# NOTE: unpickling can execute arbitrary code; these vocabulary files must
	# come from a trusted source. The files are closed as soon as they are read.
	with open(homepath + "/.data/NELL-995/vocab_e1", 'rb') as vocab_file:
		token2idx_ent, idx2token_ent, label2idx_ent, idx2label_ent = pickle.load(vocab_file)
	with open(homepath + "/.data/NELL-995/vocab_rel", 'rb') as vocab_file:
		token2idx_rel, idx2token_rel, label2idx_rel, idx2label_rel = pickle.load(vocab_file)
	ConvE_model = ConvE_double(len(token2idx_ent), len(token2idx_rel))
	if USE_CUDA:
		model_params = torch.load("../NELL-995/NELL-995_ConvE_0.2_0.3_100.model")
	else:
		# Without CUDA, remap every storage to the CPU so that a checkpoint
		# saved from a GPU run can still be deserialised.
		model_params = torch.load("../NELL-995/NELL-995_ConvE_0.2_0.3_100.model", map_location=lambda storage, loc: storage)
	ConvE_model.load_state_dict(model_params)

	# The reward-shaping model is only used for scoring; it is never trained here.
	for parameter in ConvE_model.parameters():
		parameter.requires_grad = False

	if USE_CUDA:
		ConvE_model.cuda()

	torch.cuda.empty_cache()

else:
	print("Default. Uses TransE as reward shaping model.")
	ent_embedding_path = '../NELL-995/NELL-995_100_1.0_TransE_entity_embedding.txt'
	rel_embedding_path = '../NELL-995/NELL-995_100_1.0_TransE_relation_embedding.txt'
	ent_embedding = np.loadtxt(ent_embedding_path)
	rel_embedding = np.loadtxt(rel_embedding_path)

def convert_to_one_hot(y, depth):
	"""Return one one-hot row of width `depth` per entry of the integer array `y`.

	`y` is flattened first, so the result always has shape (y.size, depth).
	"""
	flat_labels = y.reshape(-1)
	identity_rows = np.identity(depth)
	# Row k of the identity matrix is exactly the one-hot vector for label k.
	return identity_rows[flat_labels]

def REINFORCE(training_pairs, policy_nn, optimizer, num_episodes, relation=None):
	"""Train `policy_nn` with policy-gradient updates on randomly sampled training pairs.

	For each of `num_episodes` episodes one line is drawn at random from
	`training_pairs`, the agent is rolled out in the environment, and the
	policy is updated depending on the outcome:

	* steps whose reward is -1 (invalid actions) are collected and, if the
	  episode did not succeed, penalised with ``args.wrong_reward``;
	* a successful episode (``done == 1``) is reinforced with a weighted sum
	  of a global reward, a length reward (1 / path length) and a diversity
	  reward (negative mean cosine similarity to paths found so far);
	* an unsuccessful episode is updated with per-step scores from the
	  selected reward-shaping model scaled by ``args.useless_reward``, and is
	  followed by up to three attempts to learn from a teacher() episode
	  rewarded with ``args.teacher_reward``.

	After all episodes, the relation paths of successful episodes are counted
	and written to `feature_stats`, and a summary is appended to
	``logs/training/<relation>.out``.

	Relies on module-level names that are not defined in this file
	(dataPath, hidden_dim, embedding_dim, action_space, max_steps, Transition,
	teacher, path_clean) — presumably provided by ``from utils import *``;
	confirm.

	Args:
		training_pairs: sequence of text lines; each sampled line is split on
			whitespace and its first two fields are looked up in
			``env.entity2id_`` as head and tail entity.
		policy_nn: policy network, called as ``policy_nn(state, lstm_input,
			hidden, cell, now_embedding, neighbour_embeddings_list)`` and
			expected to return ``(action_probs, lstm_output, hidden, cell)``.
		optimizer: optimizer whose ``step()`` is called after every backward pass.
		num_episodes: number of training episodes.
		relation: relation name. When None a new Env is created for every
			episode instead of reusing one.
			NOTE(review): the reward-shaping branches call
			``relation.replace`` and the final log path concatenates
			`relation`, so None looks unsupported in practice — confirm.

	Returns:
		None.
	"""
	# Load the task graph into a KB; it is passed to teacher() further down.
	f = open(graphpath)
	content = f.readlines()
	f.close()
	kb = KB()
	for line in content:
		ent1, rel, ent2 = line.rsplit()
		kb.addRelation(ent1, rel, ent2) # Each line is a triple, represented with strings instead of numbers

	# Dropout layer applied to the action probabilities during rollout.
	dropout = nn.Dropout(dynamic_action_dropout_rate)

	train = training_pairs

	# Number of successful episodes so far.
	success = 0

	path_found = set()        # relation paths (joined with ' -> ') of successful episodes
	path_found_entity = []    # cleaned full paths of successful episodes
	path_relation_found = []  # filled after training from path_found_entity
	success_cnt_list = []     # cumulative success count after each episode

	env = Env(dataPath, train[0], model=args.model)
	# Initialize the environment

	for i_episode in range(num_episodes):
	# for i_episode in range(15):
		start = time.time()
		print ('Episode %d' % i_episode)
		sample = train[random.choice(range(len(training_pairs)))]
		print ('Training sample: ', sample[:-1])

		if relation is None:
			env = Env(dataPath, sample, args.model)
		else:
			# Reuse the environment; only reset the recorded path.
			env.path = []
			env.path_relations = []

		sample = sample.split()
		# State is [current entity id, target entity id, 0].
		state_idx = [env.entity2id_[sample[0]], env.entity2id_[sample[1]], 0]

		episode = []

		# Inputs recorded for steps whose action was invalid (reward == -1).
		state_batch_negative = []
		lstm_input_batch_negative = []
		hidden_batch_negative = []
		cell_batch_negative = []
		action_batch_negative = []
		now_embedding_batch_negative = []
		neighbour_embeddings_list_batch_negative = []

		# Inputs recorded for steps that moved forward (sampled or forced valid action).
		state_batch_positive = []
		lstm_input_batch_positive = []
		hidden_batch_positive = []
		cell_batch_positive = []
		action_batch_positive = []
		now_embedding_batch_positive = []
		neighbour_embeddings_list_batch_positive = []

		# Zero-initialised recurrent state. The leading 3 is presumably the
		# number of LSTM layers in policy_nn — confirm against networks.py.
		hidden_this_time = torch.zeros(3, 1, hidden_dim)
		cell_this_time = torch.zeros(3, 1, hidden_dim)
		if USE_CUDA:
			hidden_this_time = hidden_this_time.cuda()
			cell_this_time = cell_this_time.cuda()

		# Entity id reached after each forward step; used as the tails scored
		# by the reward-shaping model when the episode fails.
		forward_node_list = []

		for t in count():
		# for t in range(10):
			state_vec = floatTensor(env.idx_state(state_idx))
			state = torch.cat([state_vec, hidden_this_time[-1]], dim=1) # Only use the last layer's output
			lstm_input = state_vec.unsqueeze(1)

			now_embedding = floatTensor(env.entity2vec[[state_idx[0]]])

			# Collect the distinct neighbours of the current entity over all relations.
			connected_node_list = []
			if state_idx[0] in env.entity2link:
				for rel in env.entity2link[state_idx[0]]:
					connected_node_list.extend(env.entity2link[state_idx[0]][rel])
			connected_node_list = list(set(connected_node_list))
			if len(connected_node_list) == 0:
				# No neighbours: feed a single all-zero embedding instead.
				neighbour_embeddings_list = [torch.zeros(1, embedding_dim).cuda() if USE_CUDA else torch.zeros(1, embedding_dim)]
			else:
				neighbour_embeddings_list = [floatTensor(env.entity2vec[connected_node_list])]

			action_probs, lstm_output, hidden_new, cell_new = policy_nn(state, lstm_input, hidden_this_time, cell_this_time, now_embedding, neighbour_embeddings_list)

			# Action Dropout
			dropout_action_probs = dropout(action_probs)
			# print(dropout_action_probs.shape)
			probability = np.squeeze(dropout_action_probs.cpu().detach().numpy())
			# Re-normalise so the values sum to 1 before sampling.
			probability = probability / sum(probability)
			action_chosen = np.random.choice(np.arange(action_space), p = probability)

			reward, new_state, done = env.interact(state_idx, action_chosen)

			if reward == -1: # the action fails for this step
				state_batch_negative.append(state)
				lstm_input_batch_negative.append(lstm_input)
				hidden_batch_negative.append(hidden_this_time)
				cell_batch_negative.append(cell_this_time)
				action_batch_negative.append(action_chosen)
				now_embedding_batch_negative.append(now_embedding)
				neighbour_embeddings_list_batch_negative.append(neighbour_embeddings_list[0])

				# Force to choose a valid action to go forward
				# NOTE(review): the bare except below hides every error raised in
				# this block (including KeyboardInterrupt); in that case
				# reward/new_state/done keep the values from the failed interact().
				try:
					valid_action_list = list(env.entity2link[state_idx[0]].keys()) 
					# Restrict the (dropped-out) distribution to the valid actions and re-normalise.
					probability = probability[valid_action_list]
					# print("Line 288: ", sum(probability))
					probability = probability / sum(probability)
					# print("Line 288: ", probability)
					valid_action_chosen = np.random.choice(valid_action_list, p = probability)
					valid_reward, valid_new_state, valid_done = env.interact(state_idx, valid_action_chosen)

					reward, new_state, done = valid_reward, valid_new_state, valid_done

					if new_state == None:
						forward_node_list.append(env.entity2id_[sample[1]]) # The right tail entity
					else:
						forward_node_list.append(new_state[0])

					state_batch_positive.append(state)
					lstm_input_batch_positive.append(lstm_input)
					hidden_batch_positive.append(hidden_this_time)
					cell_batch_positive.append(cell_this_time)
					action_batch_positive.append(valid_action_chosen)
					now_embedding_batch_positive.append(now_embedding)
					neighbour_embeddings_list_batch_positive.append(neighbour_embeddings_list[0])

					# The recurrent state only advances when a forward step was taken.
					hidden_this_time = hidden_new
					cell_this_time = cell_new

				except:
					print("Cannot find a valid action!")

			else: # the action find a path that can forward
				if new_state == None:
					forward_node_list.append(env.entity2id_[sample[1]]) # The right tail entity
				else:
					forward_node_list.append(new_state[0])

				state_batch_positive.append(state)
				lstm_input_batch_positive.append(lstm_input)
				hidden_batch_positive.append(hidden_this_time)
				cell_batch_positive.append(cell_this_time)
				action_batch_positive.append(action_chosen)
				now_embedding_batch_positive.append(now_embedding)
				neighbour_embeddings_list_batch_positive.append(neighbour_embeddings_list[0])

				hidden_this_time = hidden_new
				cell_this_time = cell_new

			new_state_vec = env.idx_state(new_state)
			# Note: the stored action is the originally sampled one, even when a
			# valid action was forced afterwards.
			episode.append(Transition(state = state_vec, action = action_chosen, next_state = new_state_vec, reward = reward))

			if done or t == max_steps:
				break

			state_idx = new_state

		# Discourage the agent when it chooses an invalid step
		if len(state_batch_negative) != 0 and done != 1:
			print ('Penalty to invalid steps:', len(state_batch_negative))

			policy_nn.zero_grad()
			# One-hot mask selecting, per recorded step, the probability of the action taken.
			action_mask = byteTensor(convert_to_one_hot(np.array(action_batch_negative), depth = action_space))
			# action_prob = torch.stack(action_prob_batch_negative).squeeze(1)
			# print(state_batch_negative[0].shape)
			# Re-run the policy on all recorded invalid steps as one batch.
			state = torch.cat(state_batch_negative, dim=0)
			lstm_input = torch.cat(lstm_input_batch_negative, dim=1)
			hidden = torch.cat(hidden_batch_negative, dim=1)
			cell = torch.cat(cell_batch_negative, dim=1)
			now_embedding = torch.cat(now_embedding_batch_negative, dim=0)
			action_prob, lstm_output, hidden_new, cell_new = policy_nn(state, lstm_input, hidden, cell, now_embedding, neighbour_embeddings_list_batch_negative)
			# print(action_prob.shape)
			picked_action_prob = torch.masked_select(action_prob, action_mask)
			print(picked_action_prob)
			loss = -torch.sum(torch.log(picked_action_prob) * args.wrong_reward) # Reward for each invalid action is wrong_reward
			loss.backward(retain_graph=True)
			# NOTE(review): clip_grad_norm is deprecated in newer PyTorch in favour
			# of clip_grad_norm_ — confirm against the pinned torch version.
			torch.nn.utils.clip_grad_norm(policy_nn.parameters(), 0.2)
			optimizer.step()

		print ('----- FINAL PATH -----')
		print ('\t'.join(env.path))
		print ('PATH LENGTH', len(env.path))
		print ('----- FINAL PATH -----')

		# If the agent success, do one optimization
		if done == 1:
			print ('Success')

			path_found_entity.append(path_clean(' -> '.join(env.path)))

			success += 1

			# Compute the reward for a successful episode.
			path_length = len(env.path)
			length_reward = 1/path_length
			global_reward = 1

			if len(path_found) != 0:
				# Diversity reward: negative mean cosine similarity between this
				# path's embedding and the embeddings of all paths found before.
				path_found_embedding = [env.path_embedding(path.split(' -> ')) for path in path_found]
				curr_path_embedding = env.path_embedding(env.path_relations)
				path_found_embedding = np.reshape(path_found_embedding, (-1,embedding_dim))
				cos_sim = cosine_similarity(path_found_embedding, curr_path_embedding)
				diverse_reward = -np.mean(cos_sim)
				print ('diverse_reward', diverse_reward)
				total_reward = args.global_reward_weight * global_reward + args.length_reward_weight * length_reward + args.diverse_reward_weight * diverse_reward 
			else:
				# First successful path: the diversity weight is given to the length reward.
				total_reward = args.global_reward_weight * global_reward + (args.length_reward_weight + args.diverse_reward_weight) * length_reward
			path_found.add(' -> '.join(env.path_relations))

			# total_reward = 0.1*global_reward + 0.9*length_reward


			policy_nn.zero_grad()
			action_mask = byteTensor(convert_to_one_hot(np.array(action_batch_positive), depth = action_space))
			state = torch.cat(state_batch_positive, dim=0)
			lstm_input = torch.cat(lstm_input_batch_positive, dim=1)
			hidden = torch.cat(hidden_batch_positive, dim=1)
			cell = torch.cat(cell_batch_positive, dim=1)
			now_embedding = torch.cat(now_embedding_batch_positive, dim=0)
			action_prob, lstm_output, hidden_new, cell_new = policy_nn(state, lstm_input, hidden, cell, now_embedding, neighbour_embeddings_list_batch_positive)
			# print(action_prob.shape)
			picked_action_prob = torch.masked_select(action_prob, action_mask)
			loss = -torch.sum(torch.log(picked_action_prob) * total_reward) 
			# The reward for each step of a successful episode is total_reward
			loss.backward(retain_graph=True)
			torch.nn.utils.clip_grad_norm(policy_nn.parameters(), 0.2)
			optimizer.step()
		else:

			if (len(state_batch_positive) != 0):
				# reward shaping
				# Score every entity reached during the rollout as a candidate tail
				# of (head, relation, ?) with the selected embedding model; each
				# branch yields one score per entry of forward_node_list.

				if args.reward_shaping_model == "TransH":
					# print("Enters TransH.")
					head = ent_embedding[[env.entity2id_[sample[0]]]]
					rel_emb = rel_embedding[[env.relation2id_[relation.replace('_', ':')]]]
					norm = norm_embedding[[env.relation2id_[relation.replace('_', ':')]]]
					tail = ent_embedding[forward_node_list]
					# Subtract the component along `norm` from head and tail.
					head_proj = head - np.sum(head * norm, axis=1, keepdims=True) * norm
					tail_proj = tail - np.sum(tail * norm, axis=1, keepdims=True) * norm
					# Negative L1 distance of head_proj + rel_emb from tail_proj.
					scores = -np.sum(np.abs(head_proj + rel_emb - tail_proj), axis = 1)
					# print(scores)

				elif args.reward_shaping_model == "TransR":
					# print("Enters TransR.")
					head = ent_embedding[[env.entity2id_[sample[0]]]]
					rel_emb = rel_embedding[[env.relation2id_[relation.replace('_', ':')]]]
					norm = norm_embedding[[env.relation2id_[relation.replace('_', ':')]]].squeeze(0)
					tail = ent_embedding[forward_node_list]
					# Multiply head and tail by the relation's matrix.
					head_proj = np.matmul(norm, head.T).T
					tail_proj = np.matmul(norm, tail.T).T
					scores = -np.sum(np.abs(head_proj + rel_emb - tail_proj), axis = 1)
					# print(scores)

				elif args.reward_shaping_model == "TransD":
					# print("Enters TransD.")
					head = ent_embedding[[env.entity2id_[sample[0]]]]
					head_norm = ent_norm_embedding[[env.entity2id_[sample[0]]]]
					tail = ent_embedding[forward_node_list]
					tail_norm = ent_norm_embedding[forward_node_list]
					rel_emb = rel_embedding[[env.relation2id_[relation.replace('_', ':')]]]
					rel_norm = rel_norm_embedding[[env.relation2id_[relation.replace('_', ':')]]]
					head_proj = head + np.sum(head * head_norm, axis=1, keepdims=True) * rel_norm
					tail_proj = tail + np.sum(tail * tail_norm, axis=1, keepdims=True) * rel_norm
					scores = -np.sum(np.abs(head_proj + rel_emb - tail_proj), axis = 1)
					# print(scores)

				elif args.reward_shaping_model == "ProjE":
					# print("Enter ProjE.")
					h = ent_embedding[[env.entity2id_[sample[0]]]]
					r = rel_embedding[[env.relation2id_[relation.replace('_', ':')]]]
					ent_mat = np.transpose(ent_embedding)
					# First 100 weights scale the head, the remaining ones the relation.
					hr = h * simple_hr_combination_weights[:100] + r * simple_hr_combination_weights[100:]
					hrt_res = np.matmul(np.tanh(hr + combination_bias_hr), ent_mat)
					scores = hrt_res[0][forward_node_list]
					# Log-sigmoid of the raw scores.
					scores = torch.log(torch.sigmoid(torch.FloatTensor(scores))).numpy()
					# print(scores)

				elif args.reward_shaping_model == "ConvE":
					# print("Enters ConvE.")
					# Translate env ids into the ids used by the ConvE model.
					rel_id = TransE_to_ConvE_id_relation[env.relation2id_[relation.replace('_', ':')]]
					head_id = TransE_to_ConvE_id_entity[env.entity2id_[sample[0]]]
					tail_id = [TransE_to_ConvE_id_entity[elem] for elem in forward_node_list]

					# The model is fed a full batch of size bs: the real head sits in
					# slot 0 and the other slots are padded with entity id 0; only
					# output[0] is read afterwards.
					bs = ConvE_model.batch_size
					x_middle, output = ConvE_model(longTensor([head_id] + [0] * (bs - 1)), longTensor([rel_id] * bs))

					# Small epsilon avoids log(0).
					scores = np.log(output[0][tail_id].detach().cpu().numpy() + 10 ** -30)
					# print(scores)

				else:
					# Default (TransE): negative L1 distance of head + relation from each tail.
					head_embedding = ent_embedding[env.entity2id_[sample[0]]]
					query_embedding = rel_embedding[env.relation2id_[relation.replace('_', ':')]]
					tail_embedding = ent_embedding[forward_node_list]
					scores = -np.sum(np.abs(head_embedding + query_embedding - tail_embedding), axis = 1)

				policy_nn.zero_grad()
				action_mask = byteTensor(convert_to_one_hot(np.array(action_batch_positive), depth = action_space))
				state = torch.cat(state_batch_positive, dim=0)
				lstm_input = torch.cat(lstm_input_batch_positive, dim=1)
				hidden = torch.cat(hidden_batch_positive, dim=1)
				cell = torch.cat(cell_batch_positive, dim=1)
				now_embedding = torch.cat(now_embedding_batch_positive, dim=0)
				action_prob, lstm_output, hidden_new, cell_new = policy_nn(state, lstm_input, hidden, cell, now_embedding, neighbour_embeddings_list_batch_positive)
				# print(action_prob.shape)
				picked_action_prob = torch.masked_select(action_prob, action_mask)
				# print(picked_action_prob)
				loss = -torch.sum(torch.log(picked_action_prob) * floatTensor(scores) * args.useless_reward) 
				# Each step of an unsuccessful episode is weighted by its reward-shaping score times useless_reward
				loss.backward(retain_graph=True)
				torch.nn.utils.clip_grad_norm(policy_nn.parameters(), 0.2)
				optimizer.step()

			print ('Failed, Do one teacher guideline') # Force the agent to learn using a successful sample
			teacher_success_flag = False
			teacher_success_failed_times = 0
			# Retry until one teacher episode has been learned or three attempts failed.
			while (not teacher_success_flag) and teacher_success_failed_times < 3:
				try:
					good_episodes = teacher(sample[0], sample[1], 1, env, graphpath, knowledge_base = kb, output_mode = 1) # Episode's ID instead of state!
					if len(good_episodes) == 0:
						teacher_success_failed_times += 1
					else:
						for item in good_episodes:
							if len(item) == 0:
								teacher_success_failed_times += 1
								break

							teacher_state_batch = []
							teacher_action_batch = []
							teacher_now_embedding_batch = []
							teacher_neighbour_embeddings_list_batch = []

							# NOTE(review): total_reward is computed here but the loss
							# below uses args.teacher_reward — confirm which is intended.
							total_reward = 0.0*1 + 1*1/len(item)

							# Build the per-step inputs of the teacher episode.
							for t, transition in enumerate(item):
								teacher_state_batch.append(floatTensor(env.idx_state(transition.state)))
								teacher_action_batch.append(transition.action)
								teacher_now_embedding_batch.append(floatTensor(env.entity2vec[[transition.state[0]]]))

								connected_node_list = []
								if transition.state[0] in env.entity2link:
									for rel in env.entity2link[transition.state[0]]:
										connected_node_list.extend(env.entity2link[transition.state[0]][rel])
								connected_node_list = list(set(connected_node_list)) # Remove duplicates
								if len(connected_node_list) == 0:
									if USE_CUDA:
										neighbour_embeddings_list = torch.zeros(1, embedding_dim).cuda()
									else:
										neighbour_embeddings_list = torch.zeros(1, embedding_dim)

								else:
									neighbour_embeddings_list = floatTensor(env.entity2vec[connected_node_list])

								teacher_neighbour_embeddings_list_batch.append(neighbour_embeddings_list)

							if (len(teacher_state_batch) != 0):
								# Fresh zero recurrent state for replaying the teacher episode.
								hidden_this_time = torch.zeros(3, 1, hidden_dim)
								cell_this_time = torch.zeros(3, 1, hidden_dim)
								if USE_CUDA:
									hidden_this_time = hidden_this_time.cuda()
									cell_this_time = cell_this_time.cuda()

								state_batch_teacher = []
								lstm_input_batch_teacher = []
								hidden_batch_teacher = []
								cell_batch_teacher = []

								# Replay the teacher steps one by one to obtain the recurrent
								# states that go with each step.
								# NOTE(review): here the hidden/cell state stored for a step is
								# the one AFTER the step, whereas the rollout above stores the
								# state BEFORE the step — confirm this is intentional.
								for idx, state_vec in enumerate(teacher_state_batch):
									state_vec = floatTensor(state_vec)
									state = torch.cat([state_vec, hidden_this_time[-1]], dim=1) # Only use the last layer's output
									lstm_input = state_vec.unsqueeze(1)
									now_embedding = teacher_now_embedding_batch[idx]
									teacher_neighbour_embeddings_list = [teacher_neighbour_embeddings_list_batch[idx]]
									action_prob, lstm_output, hidden_new, cell_new = policy_nn(state, lstm_input, hidden_this_time, cell_this_time, now_embedding, teacher_neighbour_embeddings_list)
									# print(action_prob.shape)
									hidden_this_time = hidden_new
									cell_this_time = cell_new

									state_batch_teacher.append(state)
									lstm_input_batch_teacher.append(lstm_input)
									hidden_batch_teacher.append(hidden_this_time)
									cell_batch_teacher.append(cell_this_time)

								now_embedding = torch.cat(teacher_now_embedding_batch, dim=0)

								# Batched update on the whole teacher episode.
								policy_nn.zero_grad()
								action_mask = byteTensor(convert_to_one_hot(np.array(teacher_action_batch), depth = action_space))
								state = torch.cat(state_batch_teacher, dim=0)
								lstm_input = torch.cat(lstm_input_batch_teacher, dim=1)
								hidden = torch.cat(hidden_batch_teacher, dim=1)
								cell = torch.cat(cell_batch_teacher, dim=1)
								action_prob, lstm_output, hidden_new, cell_new = policy_nn(state, lstm_input, hidden, cell, now_embedding, teacher_neighbour_embeddings_list_batch)
								# print(action_prob.shape)
								picked_action_prob = torch.masked_select(action_prob, action_mask)
								loss = -torch.sum(torch.log(picked_action_prob) * args.teacher_reward) # The reward for each step of a teacher episode is teacher_reward
								loss.backward(retain_graph=True)
								torch.nn.utils.clip_grad_norm(policy_nn.parameters(), 0.2)
								optimizer.step()

								teacher_success_flag = True
							else:
								teacher_success_failed_times += 1

				except Exception as e:
					print ('Teacher guideline failed')
					# Adding 10 pushes the counter past the limit of 3, so the retry loop ends.
					teacher_success_failed_times += 10

		print ('Episode time: ', time.time() - start)
		print ('\n')
		print ("Retrain Success count: ", success)
		success_cnt_list.append(success)
	print ('Retrain Success percentage:', success/num_episodes)
	print (success_cnt_list)

	# Reduce every cleaned successful path to the items at even positions,
	# which are treated as its relations.
	for path in path_found_entity: # Only successful paths
		rel_ent = path.split(' -> ')
		path_relation = []
		for idx, item in enumerate(rel_ent):
			if idx%2 == 0:
				path_relation.append(item)
		path_relation_found.append(' -> '.join(path_relation))

	relation_path_stats = collections.Counter(path_relation_found).items()
	relation_path_stats = sorted(relation_path_stats, key = lambda x:x[1], reverse=True) # Rank the paths according to their frequency.

	# One "<relation path>\t<count>" line per distinct path, most frequent first.
	f = open(feature_stats, 'w')
	for item in relation_path_stats:
		f.write(item[0]+'\t'+str(item[1])+'\n')
	f.close()
	print ('Path stats saved')

	# Append a run summary to the per-relation training log.
	with open("logs/training/" + relation + ".out", 'a') as fw:
		fw.write(save_file_header + '_path_stats.txt' + '\n')
		fw.write('Retrain Success persentage: ' + str(success/num_episodes) + '\n')
		fw.write("Retrain success cnt list: ")
		fw.write(" ".join([str(elem) for elem in success_cnt_list]) + '\n')
		fw.write("\n")

	return 

def retrain(learning_rate = 0.001, weight_decay=0.005, dropout2 = 0.2, relation=None):
	"""Build a fresh policy network, train it with REINFORCE and save it.

	The training pairs are read from `relationPath`, the network is trained
	for ``args.num_episodes`` episodes, and the whole model object is saved to
	``models/<relation>/<save_file_header>.ckpt``.

	Args:
		learning_rate: learning rate passed to Adam.
		weight_decay: weight decay passed to Adam.
		dropout2: dropout rate passed to the policy network as ``dropout_rate``.
		relation: relation name; forwarded to REINFORCE and used to build the
			checkpoint directory, so it must be a string by the time the
			model is saved.

	Returns:
		None.
	"""
	print ('Start retraining')

	# Context manager guarantees the file is closed even if reading fails.
	with open(relationPath) as f:
		training_pairs = f.readlines()

	model = policy_nn_lstm_attn(embedding_dim=embedding_dim, action_dim=action_space, attn_dim = args.attn_dim, initializer="xavier", dropout_rate=dropout2)
	# print(model)
	if USE_CUDA:
		model.cuda()
	# The message is kept as-is, although no checkpoint is loaded in this function.
	print ("sl_policy restored")

	optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
	REINFORCE(training_pairs, policy_nn = model, optimizer = optimizer, num_episodes = args.num_episodes, relation=relation)

	# Create the checkpoint directory if needed, so a completed training run
	# is not lost to a missing folder at save time.
	model_dir = 'models/' + relation + '/'
	if not os.path.isdir(model_dir):
		os.makedirs(model_dir)
	torch.save(model, model_dir + save_file_header + ".ckpt")
	print ('Retrained model saved')

# Script entry point: runs as soon as this module is executed or imported.
# Trains on the relation chosen on the command line, using the dynamically
# scaled weight decay and dropout rate computed earlier in this file.
retrain(relation=relation, weight_decay = dynamic_weight_decay, dropout2=dynamic_dropout_rate2)

# Validate a reasoning path for an entity pair using bidirectional BFS
def bfs_two(e1,e2,path,kb,kb_inv):
	"""Check whether the relation sequence `path` connects `e1` to `e2`.

	The search runs from both ends: relations are consumed from the front of
	`path` while expanding from `e1` through `kb`, and from the back while
	expanding from `e2` through `kb_inv`. At every step the strictly smaller
	frontier on the `e1` side is expanded, otherwise the `e2` side.

	Returns True when the two frontiers share an entity once all relations
	are consumed. It also returns True (after printing a diagnostic) when a
	lookup in either knowledge base raises.
	"""
	lo = 0
	hi = len(path)
	forward_frontier = {e1}
	backward_frontier = {e2}

	while lo < hi:
		if len(forward_frontier) < len(backward_frontier):
			# Advance from the head side with the next unused leading relation.
			wanted = path[lo]
			lo += 1
			reached = set()
			for entity in forward_frontier:
				try:
					reached.update(
						edge.connected_entity
						for edge in kb.getPathsFrom(entity)
						if edge.relation == wanted
					)
				except Exception:
					print ('left', len(forward_frontier))
					print (forward_frontier)
					print ('not such entity')
					return True
			forward_frontier = reached
		else:
			# Advance from the tail side with the last unused trailing relation.
			wanted = path[hi - 1]
			hi -= 1
			reached = set()
			for entity in backward_frontier:
				try:
					reached.update(
						edge.connected_entity
						for edge in kb_inv.getPathsFrom(entity)
						if edge.relation == wanted
					)
				except Exception:
					print ('right', len(backward_frontier))
					print ('no such entity')
					return True
			backward_frontier = reached

	return len(backward_frontier & forward_frontier) != 0

def get_features(path_filename, max_length=10):
	"""Load relation paths from a path-statistics file and map them to ids.

	Every non-blank line of `path_filename` is expected to look like
	``"<rel> -> <rel> -> ...\\t<count>"``. Relation names are translated
	with the table read from `relationId_path`.

	Args:
		path_filename: file with one tab-separated "path, occurrence" record
			per line.
		max_length: paths with more relations than this are dropped
			(defaults to 10, the previously hard-coded limit).

	Returns:
		A tuple ``(useful_paths, named_paths, occurrence_paths)`` of parallel
		lists: each kept path as a list of relation ids, as a list of
		relation names, and its occurrence count.

	Raises:
		KeyError: if a path contains a relation missing from the id table.
	"""
	# Relation name -> integer id; blank lines are ignored.
	relation2id = {}
	with open(relationId_path) as f:
		for line in f:
			fields = line.split()
			if not fields:
				continue
			relation2id[fields[0]] = int(fields[1])

	with open(path_filename) as f:
		paths = f.readlines()

	print (len(paths))

	useful_paths = []
	named_paths = []
	occurrence_paths = []
	for line in paths:
		record = line.rstrip()
		if not record:
			# Tolerate blank lines instead of failing on the missing count.
			continue
		fields = record.split("\t")
		relations = fields[0].split(' -> ')
		occurrence = int(fields[1])

		if len(relations) <= max_length:
			useful_paths.append([relation2id[rel] for rel in relations])
			named_paths.append(relations)
			occurrence_paths.append(occurrence)

	# Filtering by occurrence count (e.g. keeping only paths seen more than
	# once) is intentionally not applied here.

	print ('How many paths used: ', len(useful_paths))
	return useful_paths, named_paths, occurrence_paths 
	# Paths' represented with index, string, and paths' occurence times
	# Paths' represented with index, string, and paths' occurence times

def fact_prediction_eval_logic():
	"""Score every test pair with length-weighted path features and log the AP.

	Reads the reasoning paths from the module-level ``feature_stats`` file,
	builds a forward and an inverse ``KB`` from ``dataPath_ + '/graph.txt'``,
	and for each pair listed in ``test_data_path`` sums ``1 / len(path)``
	over the paths for which ``bfs_two`` returns a truthy value. The pairs
	are ranked by that score (highest first) and the average precision over
	the positively labelled pairs is printed and appended to
	``logs/fact_prediction/<relation>.out``.

	Returns:
		None. The result is reported on stdout and in the log file.
	"""
	_, named_paths, _ = get_features(feature_stats)

	# Shorter paths contribute more: each path is weighted by 1 / length.
	length_weights = np.array([1.0 / len(path) for path in named_paths])

	kb = KB()
	kb_inv = KB()
	with open(dataPath_ + '/graph.txt') as f:
		for line in f:
			fields = line.split()
			if len(fields) < 3:
				# Skip blank / truncated triples instead of crashing.
				continue
			e1, rel, e2 = fields[0], fields[1], fields[2]
			kb.addRelation(e1, rel, e2)
			kb_inv.addRelation(e2, rel, e1)

	test_pairs = []
	test_labels = []
	with open(test_data_path) as f:
		for line in f:
			line = line.strip()
			if not line:
				continue
			parts = line.split(',')
			e1 = parts[0].replace('thing$', '')
			e2 = parts[1].split(':')[0].replace('thing$', '')
			test_pairs.append((e1, e2))
			# A pair is positive when its line ends with '+'. Stripping first
			# keeps this correct for CRLF files and for a last line that has
			# no trailing newline.
			test_labels.append(1 if line.endswith('+') else 0)

	scores_rl = []

	print ('How many queries: ', len(test_pairs))
	for idx, (e1, e2) in enumerate(test_pairs):
		print ('Query No.%d of %d' % (idx, len(test_pairs)))

		features = np.array(
			[int(bfs_two(e1, e2, path, kb, kb_inv)) for path in named_paths])
		scores_rl.append(sum(features * length_weights))

	# Stable sort: pairs with equal scores keep their file order.
	rank_stats_rl = sorted(zip(scores_rl, test_labels), key=lambda x: x[0], reverse=True)

	correct = 0
	ranks = []
	for idx, item in enumerate(rank_stats_rl):
		if item[1] == 1:
			correct += 1
			# Precision at the rank of this positive (rank = idx + 1).
			ranks.append(correct / (1.0 + idx))
	# Without any positive pair np.mean would yield NaN; report 0 instead.
	ap3 = np.mean(ranks) if ranks else 0.0
	print ('RL: ', ap3)

	with open("logs/fact_prediction/" + relation + ".out", 'a') as fw:
		fw.write(filename + '\n')
		fw.write('RL fact prediction: ' + str(ap3) + '\n')
		fw.write("\n")

# Module-level call: the fact-prediction evaluation runs as soon as execution
# reaches this line (there is no __main__ guard around it).
fact_prediction_eval_logic() # Fact Prediction

def train(kb, kb_inv, named_paths):
	"""Fit a single-unit sigmoid classifier on binary path features.

	Training pairs are read from ``dataPath_ + '/train.pairs'``. For each
	pair, one 0/1 feature per entry of ``named_paths`` is computed with
	``bfs_two``; the resulting matrix is used to train a Keras ``Sequential``
	model made of one ``Dense(1, sigmoid)`` layer.

	Args:
		kb: Forward knowledge base. Pairs whose two entities are not both in
			``kb.entities`` are skipped.
		kb_inv: Inverse knowledge base, passed through to ``bfs_two``.
		named_paths: List of paths (each a list of relation names); every
			path is one input feature.

	Returns:
		The fitted Keras ``Sequential`` model.
	"""
	train_pairs = []
	train_labels = []
	with open(dataPath_ + '/train.pairs') as f:
		for line in f:
			line = line.strip()
			if not line:
				continue
			parts = line.split(',')
			e1 = parts[0].replace('thing$', '')
			e2 = parts[1].split(':')[0].replace('thing$', '')
			if (e1 not in kb.entities) or (e2 not in kb.entities):
				continue
			train_pairs.append((e1, e2))
			# Positive when the line ends with '+'. Stripping first keeps
			# this correct for CRLF files and for a last line without a
			# trailing newline.
			train_labels.append(1 if line.endswith('+') else 0)

	# One row per pair, one 0/1 column per path.
	training_features = [
		[int(bfs_two(e1, e2, path, kb, kb_inv)) for path in named_paths]
		for (e1, e2) in train_pairs
	]

	model = Sequential()
	input_dim = len(named_paths)
	model.add(Dense(1, activation='sigmoid' ,input_dim=input_dim))
	model.compile(optimizer = 'rmsprop', loss='binary_crossentropy', metrics=['accuracy'])
	# NOTE(review): nb_epoch is the legacy spelling of `epochs`; kept as-is
	# because the installed Keras version is not visible here -- confirm.
	model.fit(np.array(training_features), np.array(train_labels), nb_epoch=300, batch_size=128)
	return model

def _query_average_precision(scored_labels):
	"""Return the average precision of one query's ``(score, label)`` pairs, or 0 if it has no positive."""
	# Stable sort: candidates with equal scores keep their input order.
	ranked = sorted(scored_labels, key=lambda item: item[0], reverse=True)
	hits = 0
	precisions = []
	for rank, item in enumerate(ranked, 1):
		if item[1] == 1:
			hits += 1
			# Precision at the rank of this positive candidate.
			precisions.append(hits / float(rank))
	if not precisions:
		return 0
	return np.mean(precisions)


def link_prediction_evaluate_logic():
	"""Evaluate link prediction with a trained path classifier and log the MAP.

	Builds a forward and an inverse ``KB`` from ``dataPath_ + '/graph.txt'``,
	loads the reasoning paths from the module-level ``feature_stats`` file
	and trains a classifier on them via ``train``. Test pairs come from
	``dataPath_ + '/sort_test.pairs'``; consecutive pairs sharing the same
	head entity form one query. Each candidate is scored with
	``model.predict`` on its 0/1 path features, an average precision is
	computed per query, and the mean over all queries is printed and
	appended to ``logs/link_prediction/<relation>.out``.

	Returns:
		None. The result is reported on stdout and in the log file.
	"""
	kb = KB()
	kb_inv = KB()
	with open(dataPath_ + '/graph.txt') as f:
		for line in f:
			fields = line.split()
			if len(fields) < 3:
				# Skip blank / truncated triples instead of crashing.
				continue
			e1, rel, e2 = fields[0], fields[1], fields[2]
			kb.addRelation(e1, rel, e2)
			kb_inv.addRelation(e2, rel, e1)

	_, named_paths, _ = get_features(feature_stats)

	model = train(kb, kb_inv, named_paths)

	test_pairs = []
	test_labels = []
	with open(dataPath_ + '/sort_test.pairs') as f:
		for line in f:
			line = line.strip()
			if not line:
				continue
			parts = line.split(',')
			e1 = parts[0].replace('thing$', '')
			e2 = parts[1].split(':')[0].replace('thing$', '')
			if (e1 not in kb.entities) or (e2 not in kb.entities):
				continue
			test_pairs.append((e1, e2))
			# Positive when the line ends with '+'. Stripping first keeps
			# this correct for CRLF files and for a last line without a
			# trailing newline.
			test_labels.append(1 if line.endswith('+') else 0)

	aps = []
	current_query = None
	scored = []  # (score, label) pairs of the query being accumulated
	for (e1, e2), label in zip(test_pairs, test_labels):
		if scored and e1 != current_query:
			# Head entity changed: close the previous query and start anew.
			aps.append(_query_average_precision(scored))
			scored = []
		current_query = e1

		features = [int(bfs_two(e1, e2, path, kb, kb_inv)) for path in named_paths]
		score = model.predict(np.reshape(features, [1,-1]))
		# Keep a plain float so the ranking compares scalars, not arrays.
		scored.append((float(np.ravel(score)[0]), label))

	# Close the last query. It goes through the same helper as the others, so
	# a query without positives contributes 0 instead of NaN.
	if scored:
		aps.append(_query_average_precision(scored))

	# No usable test pair at all: report 0 rather than failing on empty input.
	mean_ap = np.mean(aps) if aps else 0.0
	print ('RL MAP: ', mean_ap)

	with open("logs/link_prediction/" + relation + ".out", 'a') as fw:
		fw.write(filename + '\n')
		fw.write('RL link prediction MAP: ' + str(mean_ap) + '\n')
		fw.write('\n')

# Module-level call: the link-prediction evaluation (which also trains the
# classifier via train()) runs as soon as execution reaches this line.
link_prediction_evaluate_logic() # Link Prediction

